import sys
from typing import Dict


sys.path.insert(0, "/workspace/geartrain_py")
sys.path.insert(0, "/workspace/gear_yolo")
sys.path.insert(0, "/workspace/gear-cv")
from torch import device, dtype
from torch._C import device, dtype

from geartrain.tokenization_utils import PreTrainedTokenizer


from geartrain.backends.base import Backend
from geartrain.image_processing_utils import BaseImageProcessor, BatchFeature
from geartrain.modelcard import ModelCard
from geartrain.pipelines.base import ArgumentHandler, Pipeline
from gearyolo import YOLOv8DetectPipeline
from gearyolo.detectHead import DetectHead
from gearcv.benchmark import benchmark
import cv2


class testPreprocess(BaseImageProcessor):
    """Pass-through image processor.

    Performs no transformation: it simply wraps whatever it is given in a
    ``BatchFeature`` under the ``"model_inputs"`` key, so the pipeline has a
    uniform container to hand to the model backend.

    NOTE(review): the class name keeps its original (non-PascalCase) form
    because it is referenced elsewhere in this file.
    """

    # The original redundant ``__init__`` (which only forwarded **kwargs to
    # super().__init__) has been removed; construction behavior is unchanged.

    def preprocess(self, images, **kwargs) -> BatchFeature:
        """Wrap *images* unmodified in a BatchFeature.

        Args:
            images: Raw image input, passed through as-is (any type the
                downstream model accepts — not validated here).
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            BatchFeature with the input stored under ``"model_inputs"``.
        """
        return BatchFeature({"model_inputs": images})

class VideoYOLOv8DetectPipeline(Pipeline):
    """Pipeline driving a YOLOv8 ONNX detector through the generic Pipeline API.

    The ``model`` and ``image_processor`` arguments are always replaced: a
    ``YOLOv8DetectPipeline`` backend is built from the required ``model_path``
    keyword argument, and a pass-through ``testPreprocess`` is used as the
    image processor.
    """

    def __init__(
        self,
        model: Backend = None,
        tokenizer: PreTrainedTokenizer | None = None,
        image_processor: BaseImageProcessor | None = None,
        modelcard: ModelCard | None = None,
        framework: str | None = None,
        task: str = "",
        args_parser: ArgumentHandler = None,
        device: int | device = None,
        torch_dtype: str | dtype | None = None,
        binary_output: bool = False,
        **kwargs,
    ):
        """Build the YOLOv8 backend and delegate to ``Pipeline.__init__``.

        Args:
            framework: Inference framework identifier forwarded to the
                YOLOv8 backend (e.g. ``"ort"``).
            **kwargs: Must contain ``model_path`` (path to the ONNX weights);
                remaining keys are forwarded to ``Pipeline.__init__``.

        Raises:
            KeyError: If the required ``model_path`` keyword is missing.
        """
        # ``model_path`` is mandatory; kwargs (including model_path) are still
        # forwarded to super() unchanged to preserve the original behavior.
        model_path = kwargs["model_path"]
        # The caller-supplied ``model`` is intentionally discarded in favor of
        # a freshly constructed YOLOv8 backend.
        model = YOLOv8DetectPipeline(
            model_path=model_path,
            framework=framework,
        )
        image_processor = testPreprocess()
        super().__init__(
            model,
            tokenizer,
            image_processor,
            modelcard,
            framework,
            task,
            args_parser,
            device,
            torch_dtype,
            binary_output,
            **kwargs,
        )

    def preprocess(self, inputs_dict, **kwargs) -> Dict:
        """Extract the raw source and run it through the image processor.

        Args:
            inputs_dict: Mapping with the raw input under ``"model_inputs"``.
            **kwargs: Preprocess parameters from ``_sanitize_parameters``.
                NOTE(review): a ``timeout`` key may arrive here but is
                currently unused — confirm whether it should be honored.

        Returns:
            The BatchFeature produced by the image processor.
        """
        source = inputs_dict["model_inputs"]
        return self.image_processor(images=source)

    def _forward(self, model_inputs):
        """Run the YOLOv8 backend on the preprocessed inputs."""
        return self.model(model_inputs)

    def postprocess(self, model_outputs):
        """Return the raw model outputs unchanged."""
        return model_outputs

    def _sanitize_parameters(self, **kwargs):
        """Split **kwargs into (preprocess, forward, postprocess) dicts.

        Recognized keys: ``timeout`` → preprocess params, ``threshold`` →
        postprocess params. Forward params are always empty.
        """
        preprocess_params = {}
        if "timeout" in kwargs:
            preprocess_params["timeout"] = kwargs["timeout"]

        postprocess_kwargs = {}
        if "threshold" in kwargs:
            postprocess_kwargs["threshold"] = kwargs["threshold"]
        # Stray debug print(preprocess_params) removed.
        return preprocess_params, {}, postprocess_kwargs

if __name__ == "__main__":
    # Smoke test: run the detector once on the bundled sample image.
    pipeline = VideoYOLOv8DetectPipeline(
        model_path="/workspace/gear_yolo/models/yolov8n.onnx",
        framework="ort",
        timeout=1,
    )
    detections = pipeline({"model_inputs": "/workspace/gear_yolo/assets/bus.jpg"})
    print(detections)